Median of 2 sorted lists

Given two sorted lists, find the median of the combined lists.

Algorithm

Target index

Definitions:

  • indices are 0-indexed
  • i // 2 denotes floor division, so 4 // 2 = 2 and 5 // 2 = 2

If we are given a sorted list, all we need is the length (N) to determine the median.

If our list is an odd length, median = \(\text{list}[N \; // \;\; 2]\).

If our list is an even length, median = \(\frac{\text{list}[(N \; // \; 2) - 1] + \text{list}[N \; // \;\; 2]}{2}\)

For instance:

  • if len(list) = 5, median = list[2].
  • if len(list) = 6, median = \(\frac{\text{list}[2] + \text{list}[3]}{2}\)

We define our target index as \(N \; // \;\; 2\).
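As a minimal sketch of the rule above (the function name `median_of_sorted` is only illustrative and assumes a non-empty, already sorted list), the target index can be used like this:

```python
def median_of_sorted(values: list[float]) -> float:
    """Median of an already sorted, non-empty list via the target index N // 2."""
    n = len(values)
    target = n // 2  # target index as defined above
    if n % 2 == 1:
        # odd length: the element at the target index is the median
        return values[target]
    # even length: average the two elements around the split
    return (values[target - 1] + values[target]) / 2


print(median_of_sorted([1, 3, 5, 7, 9]))      # 5
print(median_of_sorted([1, 3, 5, 7, 9, 11]))  # 6.0
```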

High level logic

Let’s say we have two sorted lists, a and b, such that len(a) = na and len(b) = nb.

The combined list c would be \(c = (\text{a} \cup \text{b})_{sorted}\).

If we split c at index N // 2 (with N = na + nb = len(c)) into two groups, c_low and c_high, such that len(c_low) = N // 2, all we would need to know is max(c_low) and min(c_high). \[ \text{median} = \begin{cases} \min (\text{c\_high}), \; \text{N is odd} \\ \frac{\max(\text{c\_low}) + \min(\text{c\_high})}{2}, \; \text{N is even} \end{cases} \]

In other words, c_low and c_high don’t have to be sorted as long as we know their max and min respectively.
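To illustrate that only the two boundary values matter, here is a small sketch. The helper `median_from_split` is a hypothetical name, and it assumes that len(c_low) = N // 2 and that no element of c_low exceeds any element of c_high. Shuffling each group does not change the result.

```python
import random


def median_from_split(c_low: list[float], c_high: list[float]) -> float:
    """Median of the combined data given only the low and high groups.

    Assumes len(c_low) == N // 2 and every element of c_low is <= every
    element of c_high; neither group needs to be sorted.
    """
    n = len(c_low) + len(c_high)
    if n % 2 == 1:
        return min(c_high)
    return (max(c_low) + min(c_high)) / 2


c = sorted([1, 4, 9] + [2, 3, 10, 12])  # N = 7
c_low, c_high = c[: len(c) // 2], c[len(c) // 2 :]
random.shuffle(c_low)   # order inside each group is irrelevant
random.shuffle(c_high)
print(median_from_split(c_low, c_high))  # 4
```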

Turning our attention back to a and b, we can arbitrarily split a into two pieces such that len(a_low) = na_low and len(a_high) = na_high.

block-beta
  block:a
    columns 7
    
    b0_1["a"] 
    space:6

    b1_1["a_low_min"] 
    b1_2["..."] 
    b1_3["a_low_max"] 
    space 
    b1_5["a_high_min"] 
    b1_6["..."] 
    b1_7["a_high_max"]
  end

classDef background fill:#E5E4E2,stroke:#E5E4E2
classDef min fill:#89CFF0,color:#000000
classDef max fill:#7393B3,color:#000000
classDef dot color:#000000

class a,b0_1 background
class b1_3 max
class b1_5 min
class b1_2,b1_6 dot

Next we will split b into two parts such that len(b_low) = nb_low and len(b_high) = nb_high.

We choose the split point in b such that na_low + nb_low = N // 2.

block-beta
  block:b
    columns 7
    b0["b"] space:6
    b1["b_low_min"] b2["..."] b3["b_low_max"] space b5["b_high_min"] b6["..."] b7["b_high_max"]
  end

classDef background fill:#E5E4E2,stroke:#E5E4E2
classDef min fill:#89CFF0,color:#000000
classDef max fill:#7393B3,color:#000000
classDef dot color:#000000

class b,b0 background
class b3 max
class b5 min
class b2,b6 dot

\[ \text{If both of the following are true: } \begin{cases} \text{a\_low\_max} \le \text{b\_high\_min} \\ \text{b\_low\_max} \le \text{a\_high\_min} \end{cases} \tag{1} \]

\[ \text{Then, because a and b are each sorted: } \max (\text{a\_low} \cup \text{b\_low}) \le \min (\text{a\_high} \cup \text{b\_high}) \]

\[ \text{Thus, from a set perspective: } \begin{cases} \text{a\_low} \cup \text{b\_low} = \text{c\_low} \\ \text{a\_high} \cup \text{b\_high} = \text{c\_high} \end{cases} \]

\[ \text{Therefore: } \begin{cases} \max (\text{c\_low}) = \max (\text{a\_low\_max}, \text{b\_low\_max}) \\ \min (\text{c\_high}) = \min (\text{a\_high\_min}, \text{b\_high\_min}) \\ \end{cases} \]

This means all we need to do is find the split that satisfies the split requirements (1) from above and we can calculate the overall median.
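The split requirements can be checked directly. The sketch below is a deliberately simple linear scan over every split of a, not the binary search itself. The name `median_by_split_scan` is hypothetical, and an empty piece is treated as contributing -inf or +inf, which is an assumption made here to keep the comparison well defined.

```python
import math


def median_by_split_scan(a: list[float], b: list[float]) -> float:
    """Try every split of a until the split requirements (1) hold.

    Linear scan for illustration only; assumes both lists are sorted and
    at least one of them is non-empty.
    """
    n = len(a) + len(b)
    half = n // 2
    for na_low in range(len(a) + 1):
        nb_low = half - na_low  # enforce na_low + nb_low = N // 2
        if not 0 <= nb_low <= len(b):
            continue
        # An empty piece contributes -inf / +inf so it never blocks a split
        a_low_max = a[na_low - 1] if na_low > 0 else -math.inf
        a_high_min = a[na_low] if na_low < len(a) else math.inf
        b_low_max = b[nb_low - 1] if nb_low > 0 else -math.inf
        b_high_min = b[nb_low] if nb_low < len(b) else math.inf
        if a_low_max <= b_high_min and b_low_max <= a_high_min:
            low_max = max(a_low_max, b_low_max)
            high_min = min(a_high_min, b_high_min)
            return high_min if n % 2 == 1 else (low_max + high_min) / 2
    raise ValueError("no valid split found; are both lists sorted?")


print(median_by_split_scan([1, 2, 3], [40, 50]))  # 3
print(median_by_split_scan([10, 20], [40, 50]))   # 30.0
```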

Corner cases

Conditions that might cause issues in the code:

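  • every element of a is less than every element of b, or vice versa (the correct split then sits at one end of a list)
  • one of the lists is empty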

Pseudo code

Given two sorted lists, a and b, we can assume without loss of generality that a is the shorter of the two. This sets up a binary search over the shorter list, so the running time is O(log N), where N is the length of the shorter list.

Perform a binary search of a where each iteration selects an index of a to be considered a_high_min.

For each iteration:

  • Check if the split requirements (1) have been met.
    • Requirements met:
      • Calculate median and exit
    • a_low_max > b_high_min
      • The split is too high, use the midpoint of a_low for the next iteration
    • b_low_max > a_high_min
      • The split is too low, use the midpoint of a_high for the next iteration

To handle the corner cases, we will add -inf to the low end of each list and +inf to the high end of each list. That way our split point will never be the end of the list and the list will never be empty. Also, by adding evenly we do not change the median.
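The loop and the padding described above can be sketched in a compact, iterative form. This is only one possible arrangement, written independently of the class-based implementation in this document. The function name `median_two_sorted` and the inclusive `low`/`high` bounds are choices made for this sketch, and it assumes both inputs are sorted and at least one is non-empty.

```python
import math


def median_two_sorted(nums1: list[float], nums2: list[float]) -> float:
    """Median of two sorted lists via binary search on the shorter one.

    Iterative sketch with -inf/+inf padding; assumes at least one list
    is non-empty.
    """
    # a is the shorter list, b the longer one
    a, b = (nums1, nums2) if len(nums1) <= len(nums2) else (nums2, nums1)
    n = len(a) + len(b)
    # Pad both lists so a split never touches a real list boundary
    pa = [-math.inf, *a, math.inf]
    pb = [-math.inf, *b, math.inf]
    half = (len(pa) + len(pb)) // 2  # target index in the padded, combined list

    low, high = 1, len(pa) - 1  # candidate positions of a_high_min in pa
    while low <= high:
        i = (low + high) // 2   # pa[i] plays the role of a_high_min
        j = half - i            # pb[j] plays the role of b_high_min
        if pa[i - 1] > pb[j]:
            high = i - 1        # split too high: search a_low
        elif pb[j - 1] > pa[i]:
            low = i + 1         # split too low: search a_high
        else:
            low_max = max(pa[i - 1], pb[j - 1])
            high_min = min(pa[i], pb[j])
            return high_min if n % 2 == 1 else (low_max + high_min) / 2
    raise ValueError("inputs must be sorted")


print(median_two_sorted([1, 2, 3], [40, 50]))  # 3
print(median_two_sorted([10, 20], [40, 50]))   # 30.0
```

Because each list gains exactly one -inf and one +inf, the parity of the combined length is unchanged, so the odd/even rule can still be decided from the unpadded length.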

Setup

Imports

Code
import numpy as np
import pandas as pd
import textwrap
from IPython.display import Markdown, display
from rich.console import Console
from rich.table import Table
from copy import copy, deepcopy

Implementation

Components

Code
def pad(x: list) -> list:
    """Pad the beginning and end of a list with -inf and inf respectively"""
    return [-float("inf")] + x + [float("inf")]

class Pointers():
    """Data class to hold current split target and search bounds"""
    def __init__(self, high: int, target: int, low: int):
        self.high = high
        self.target = target
        self.low = low

class OrderedList:
    """Data and methods to run and track the median search"""
    def __init__(self, nums: list, target=None):
        self.list = nums
        # pad the given list to keep targets from being outside the bounds of the list
        self.padded_list = pad(nums)
        # initialize the search bounds with the split in the middle of the padded list
        self.search_idx = Pointers(low=0, target=self.n // 2, high=self.n)
        if target is not None:
            # override the default split with the given target
            self.set_target(target)

    def set_target(self, target):
        """Manually set target index"""
        self.search_idx.target = target

    @property
    def n(self):
        """Return the length of the padded list"""
        return len(self.padded_list)
        
    @property
    def low_max(self):
        """Return the max value of the list below the split"""
        return self.padded_list[self.search_idx.target - 1]
        
    @property
    def high_min(self):
        """Return the min value of the list above the split"""
        return self.padded_list[self.search_idx.target]
        
    def next_target(self, search_high=True):
        """Update the search target and bounds

        We're performing a binary search so we need to track the current value (target) 
        and lower and upper bounds of the search.

        If search_high is true we will search above the target, otherwise we will 
        search below the target.

        For example we would start with the target in the middle of the list and the 
        lower and upper bounds at the beginning and end of the list respectively:

        [L--------------------T--------------------U]
        L := lower, U := upper

        search high would confine the search to the upper half:
        [---------------------L----------T---------U]

        if we next ran search_high=False we would get:
        [---------------------L-----T----U----------]
        
        We continue searching until the target, T, is the correct split.
        """
        if search_high:
            self.search_idx.low = self.search_idx.target
        else:
            self.search_idx.high = self.search_idx.target

        # set the target to half way between the low and the high bounds
        self.search_idx.target = (self.search_idx.low + self.search_idx.high) // 2

    def __str__(self):
        """Create a table representation of the OrderedList"""
        # configurable values
        # number of characters for the display of where the split value lies
        w = 20

        # derived values
        t = self.search_idx.target
        n = self.n
        
        low_width = int(t/n * w)
        low_buffer = '-' * low_width
        
        high_width = w - low_width
        high_buffer = '=' * high_width
        
        # Create table with no default borders
        table = Table(show_header=True, header_style="bold purple4")
        
        # Add columns
        table.add_column("Split", style="dodger_blue2", justify="center")
        table.add_column("max low", style="grey30", justify="right")
        table.add_column("min high", style="grey11")
        table.add_column("Split Location", style="dodger_blue2", justify="left")
        table.add_column("N", style="grey0")
        
        # Add data, order corresponds to columns
        table.add_row(
            f"{self.search_idx.target}",
            str(self.low_max),
            str(self.high_min),
            f"[{low_buffer}|{high_buffer}]", 
            str(self.n),
        )

        # get string representation from console
        console = Console()
        with console.capture() as capture:
            console.print(table)
        str_output = capture.get()
        return str_output

    def __repr__(self):
        """Make this the same as the str representation"""
        return self.__str__()
        

Display

Code
class Step():
    """Store and render the status of an intermediate search step"""
    def __init__(self, a, b, names=("a", "b")):
        # Create deep copies so subsequent updates don't affect this snapshot
        self.a = deepcopy(a)
        self.b = deepcopy(b)
        self.names = names

    @property
    def table(self):
        """Create a table with the important values"""
        ol_list = [self.a, self.b]
        # Create table with no default borders
        table = Table(show_header=True, header_style="bold purple4", safe_box=True)
        
        # Add columns
        table.add_column("List", style="bright_black", justify="center")
        table.add_column("Split", style="dodger_blue2", justify="center")
        table.add_column("max low", style="grey30", justify="right")
        table.add_column("min high", style="grey11")
        table.add_column("Split Location", style="dodger_blue2", justify="left")
        table.add_column("N", style="grey0")
        for i, ol in zip(self.names, ol_list):
            # Add data
            t = ol.search_idx.target
            n = ol.n
            w = 20
            low_width = int(t/n * w)
            low_buffer = '-' * low_width
            
            high_width = w - low_width
            high_buffer = '=' * high_width
            
            table.add_row(
                i,
                f"{ol.search_idx.target}",
                str(ol.low_max),
                str(ol.high_min),
                f"[{low_buffer}|{high_buffer}]", 
                str(ol.n),
            )
        console = Console()
        
        with console.capture() as capture:
            console.print(table)
        str_output = capture.get()
        return str_output

    def __repr__(self):
        return self.table
    
def highlight_ol(ol: OrderedList):
    """Use pandas styling to display the lower and upper halves of the list"""
    # lower
    green_cols = list(range(ol.search_idx.low, ol.search_idx.target))
    # upper
    blue_cols = list(range(ol.search_idx.target, ol.search_idx.high))
    
    s = (
        pd.Series(ol.padded_list).to_frame().T
        .style
        .set_properties(subset=green_cols, **{"background-color": "lightgreen"})
        .set_properties(subset=blue_cols, **{"background-color": "lightblue"})
    )
    return s

Algo

Code
def find_median(a, b, target_index, steps=None):
    """Find the median by finding the appropriate split

    Check the split requirements:
    The split is correct if (a.low_max <= b.high_min) and (b.low_max <= a.high_min).

    Otherwise the split is too high or too low: adjust the search target and 
    boundaries appropriately and try again.
    """
    if isinstance(steps, list):
        steps += [Step(a, b)]
        
    # check split requirements
    a_split_too_high = a.low_max > b.high_min
    a_split_too_low = b.low_max > a.high_min
    
    if a_split_too_high:
        # split was too high, search a_low for next target
        a.next_target(search_high=False)
        # adjust 'b' split according to 'a' split
        b.search_idx.target = target_index - a.search_idx.target
        # new search
        a, b = find_median(a, b, target_index, steps)
    elif a_split_too_low:
        # split was too low, search a_high for next target
        a.next_target(search_high=True)
        # adjust 'b' split according to 'a' split
        b.search_idx.target = target_index - a.search_idx.target
        # new search
        a, b = find_median(a, b, target_index, steps)
        
    return a, b
    
def calculate_median(a, b):
    """Given a correct split of lists a and b, calculate the overall median"""
    n = a.n + b.n
    low_max = max(a.low_max, b.low_max)
    high_min = min(a.high_min, b.high_min)
    if n % 2 == 0:
        med = (low_max + high_min) / 2.0
    else:
        med = high_min
        
    return med
    
def main(nums1: list[int], nums2: list[int], steps=None):
    """Find the median of the combined lists nums1 and nums2"""
    a, b = OrderedList(nums1), OrderedList(nums2)
    # make sure a contains the shorter list
    a, b = (a, b) if a.n <= b.n else (b, a)
    
    # adjust 'b' split according to 'a' split
    target_index = (a.n + b.n) // 2
    b.set_target(target_index - a.search_idx.target)

    # start search
    a, b = find_median(a, b, target_index, steps=steps)

    # calculate the median
    med = calculate_median(a, b)
    
    return a, b, med

Random data test

Code
def random_data_test():
    # first ordered list
    # random length from 1 to 99 (the upper bound of randint is exclusive)
    na = np.random.randint(1, 100)

    # random integers from 0 to 59
    a = np.random.randint(0, 60, na).tolist()
    a.sort()

    # second ordered list
    # random length from 1 to 99
    nb = np.random.randint(1, 100)
    # random integers from 40 to 99
    b = np.random.randint(40, 100, nb).tolist()
    b.sort()

    # sorted combined list
    # np.concatenate also works on NumPy versions older than 2.0
    c = np.concatenate([a, b]).astype(int).tolist()
    c.sort()
    median_calculated = np.median(c)
    
    record_steps = []
    a_final, b_final, median_algo = main(a, b, record_steps)

    output = dict(
        a_final=a_final,
        b_final=b_final,
        c=OrderedList(c),
        median_calculated=median_calculated,
        median_algo=median_algo,
        steps=record_steps,
    )

    return output

def evaluate_output(output):
    calculated_equal_algo = (
        'YES' if output['median_calculated'] == output['median_algo'] else 'NO'
    )
    print(
        textwrap.dedent(
            f"""
            Is the calculated median the same as the median from the algo?
            --> {calculated_equal_algo} <--
            
            - calculated median = {output['median_calculated']}
            - algo median = {output['median_algo']}
            """
        )
    )
    print("The combined sorted list:")
    display(output["c"])
    
    print("\nSearch steps:")
    for step in output["steps"]:
        display(step)
Code
evaluate_output(random_data_test())

Is the calculated median the same as the median from the algo?
--> YES <--

- calculated median = 49.5
- algo median = 49.5

The combined sorted list:

┏━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━┓
┃ Split  max low  min high  Split Location           N   ┃
┡━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━┩
│  72         49  50        [----------|==========]  144 │
└───────┴─────────┴──────────┴─────────────────────────┴─────┘

Search steps:

┏━━━━━━┳━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━┓
┃ List  Split  max low  min high  Split Location           N  ┃
┡━━━━━━╇━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━┩
│  a     30         67  68        [---------|===========]  61 │
│  b     43         36  38        [----------|==========]  85 │
└──────┴───────┴─────────┴──────────┴─────────────────────────┴────┘

┏━━━━━━┳━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━┓
┃ List  Split  max low  min high  Split Location           N  ┃
┡━━━━━━╇━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━┩
│  a     15         55  57        [----|================]  61 │
│  b     58         45  45        [-------------|=======]  85 │
└──────┴───────┴─────────┴──────────┴─────────────────────────┴────┘

┏━━━━━━┳━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━┓
┃ List  Split  max low  min high  Split Location           N  ┃
┡━━━━━━╇━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━┩
│  a      7         50  50        [--|==================]  61 │
│  b     66         49  49        [---------------|=====]  85 │
└──────┴───────┴─────────┴──────────┴─────────────────────────┴────┘

┏━━━━━━┳━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━┓
┃ List  Split  max low  min high  Split Location           N  ┃
┡━━━━━━╇━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━┩
│  a      3         42  43        [|====================]  61 │
│  b     70         51  51        [----------------|====]  85 │
└──────┴───────┴─────────┴──────────┴─────────────────────────┴────┘

┏━━━━━━┳━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━┓
┃ List  Split  max low  min high  Split Location           N  ┃
┡━━━━━━╇━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━┩
│  a      5         47  48        [-|===================]  61 │
│  b     68         50  50        [----------------|====]  85 │
└──────┴───────┴─────────┴──────────┴─────────────────────────┴────┘

┏━━━━━━┳━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━┓
┃ List  Split  max low  min high  Split Location           N  ┃
┡━━━━━━╇━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━┩
│  a      6         48  50        [-|===================]  61 │
│  b     67         49  50        [---------------|=====]  85 │
└──────┴───────┴─────────┴──────────┴─────────────────────────┴────┘

Specific tests

Code
nums1 = []
nums2 = [3, 4, 5]

record_steps = []
a, b, med = main(nums1, nums2, record_steps)

assert(med == 4)
print("Success")
Success
Code
nums1 = [1, 2, 3]
nums2 = [40, 50]

record_steps = []
a, b, med = main(nums1, nums2, record_steps)

assert(med == 3)
print("Success")
Success
Code
nums1 = [10, 20]
nums2 = [40, 50]

record_steps = []
a, b, med = main(nums1, nums2, record_steps)

assert(med == 30)
print("Success")
Success